import torch
import warnings
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, sampler
from torch.utils.data.distributed import DistributedSampler

from .imb_dataset import *
from .ssl_dataset import *
from .lnl_datasets import *

import logging
logger = logging.getLogger(__name__)

def get_ssl_dataloader(args):
    """Create the three loaders used for semi-supervised (SSL) training.

    The datasets are built by ``DATASET_GETTERS[args.dataset]`` from the
    directory ``'../database/' + args.dataset.upper()``. In distributed mode
    (``args.local_rank != -1``), ranks other than 0 block on a barrier until
    rank 0 has finished building the datasets.

    Args:
        args: namespace providing ``local_rank``, ``dataset``, ``batch_size``,
            ``mu`` (multiplier applied to ``batch_size`` for the unlabeled
            loader) and ``num_workers``.

    Returns:
        tuple ``(labeled_loader, unlabeled_loader, test_loader)``. Both
        training loaders use a ``RandomSampler`` (single process) or a
        ``DistributedSampler`` (distributed) and drop the last incomplete
        batch. The test loader iterates sequentially.
    """
    rank = args.local_rank
    follower = rank not in (-1, 0)

    # Non-zero ranks wait here until rank 0 has built the datasets.
    if follower:
        torch.distributed.barrier()
    build_datasets = DATASET_GETTERS[args.dataset]
    labeled_set, unlabeled_set, test_set = build_datasets(
        args, '../database/' + args.dataset.upper())
    # Rank 0 only reaches the barrier after building, which releases the others.
    if rank == 0:
        torch.distributed.barrier()

    make_train_sampler = DistributedSampler if rank != -1 else RandomSampler

    labeled_loader = DataLoader(
        labeled_set,
        sampler=make_train_sampler(labeled_set),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True,
    )
    unlabeled_loader = DataLoader(
        unlabeled_set,
        sampler=make_train_sampler(unlabeled_set),
        batch_size=args.batch_size * args.mu,
        num_workers=args.num_workers,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_set,
        sampler=SequentialSampler(test_set),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    # NOTE(review): this barrier has no matching rank-0 call in this function.
    # Presumably the caller issues ``if args.local_rank == 0:
    # torch.distributed.barrier()`` later (e.g. after model construction).
    # Confirm this; otherwise non-zero ranks will block here indefinitely.
    if follower:
        torch.distributed.barrier()

    return labeled_loader, unlabeled_loader, test_loader

def get_lnl_dataloader(args):
    """Build the data loaders for learning-with-noisy-labels (LNL) experiments.

    Dispatches on ``args.dataset`` (case-insensitive) to the matching loader
    factory and returns whatever that factory's ``getDataLoader()`` produces.

    Args:
        args: namespace with ``dataset``, ``root``, ``batch_size`` and
            ``num_workers``. The MNIST/CIFAR branch also reads ``seed``,
            ``noise_type`` and ``noise_rate``.

    Returns:
        The result of ``getDataLoader()`` on the selected factory.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.
    """
    name = args.dataset.lower()
    if name in ('mnist', 'cifar10', 'cifar100'):
        factory = DatasetGenerator(
            batch_size=args.batch_size,
            data_path=args.root,
            seed=args.seed,
            num_workers=args.num_workers,
            is_asym=args.noise_type == 'asymmetric',
            # The original (non-lowercased) spelling is passed through
            # unchanged; how DatasetGenerator matches it is not visible here.
            dataset=args.dataset,
            noise_rate=args.noise_rate,
        )
    elif name == 'clothing1m':
        factory = Clothing1MDatasetLoader(
            data_path=args.root,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
    elif name == 'webvision':
        # Training and validation data share the same root directory.
        webvision_root = os.path.join(args.root, 'WebVision')
        factory = WebVisionDatasetLoader(
            train_batch_size=args.batch_size,
            eval_batch_size=args.batch_size,
            train_data_path=webvision_root,
            valid_data_path=webvision_root,
            num_of_workers=args.num_workers,
        )
    else:
        raise NotImplementedError(
            f"Unsupported noisy-label dataset: {args.dataset!r}")
    return factory.getDataLoader()

def get_vic_dataloader(args):
    """Build train/validation loaders for MNIST, CIFAR-10 or CIFAR-100.

    Args:
        args: namespace providing ``dataset`` (matched case-insensitively
            against ``mnist``, ``cifar10`` and ``cifar100``), ``root``,
            ``batch_size`` and ``num_workers``. Each dataset is stored in
            (and downloaded to, if missing) ``<root>/MNIST``,
            ``<root>/CIFAR10`` or ``<root>/CIFAR100``.

    Returns:
        dict with key ``'train'`` (shuffled) and key ``'val'`` (not shuffled).
        Each value is a ``torch.utils.data.DataLoader`` with ``pin_memory=True``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.
    """
    name = args.dataset.lower()
    if name == 'mnist':
        MEAN = [0.1307]
        STD = [0.3081]
        # MNIST uses the same deterministic transform for train and val.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD)])
        data_root = os.path.join(args.root, 'MNIST')
        train_dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
        val_dataset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform)
    elif name == 'cifar10':
        CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
        CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
        # Augmentation (random crop + flip) is applied to the training split only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
        data_root = os.path.join(args.root, 'CIFAR10')
        train_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
        val_dataset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_val)
    elif name == 'cifar100':
        CIFAR_MEAN = [0.5071, 0.4865, 0.4409]
        CIFAR_STD = [0.2673, 0.2564, 0.2762]
        # Augmentation (crop + flip + rotation) is applied to the training split only.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(20),
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
        transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(CIFAR_MEAN, CIFAR_STD)])
        data_root = os.path.join(args.root, 'CIFAR100')
        train_dataset = datasets.CIFAR100(root=data_root, train=True, download=True, transform=transform_train)
        val_dataset = datasets.CIFAR100(root=data_root, train=False, download=True, transform=transform_val)
    else:
        # Previously this only warned and then crashed with UnboundLocalError
        # on train_dataset; fail fast with a clear message instead.
        raise NotImplementedError(f"Dataset is not listed: {args.dataset!r}")

    data_loaders = {
        'train': torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            num_workers=args.num_workers, pin_memory=True,
        ),
        'val': torch.utils.data.DataLoader(
            val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.num_workers, pin_memory=True,
        ),
    }
    logger.info("Num of train %d", len(train_dataset))
    logger.info("Num of test %d", len(val_dataset))
    return data_loaders

def get_imb_dataloader(args):
    """Construct the long-tailed (class-imbalanced) dataset wrapper for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` (matched case-insensitively),
            ``distributed``, ``root``, ``batch_size`` and ``num_workers``. The
            CIFAR branches also read ``imb_type`` and ``imb_factor``.

    Returns:
        The dataset wrapper instance (``CIFAR10_LT``, ``CIFAR100_LT``,
        ``Places_LT``, ``ImageNet_LT`` or ``iNa2018``).

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported dataset.
    """
    name = args.dataset.lower()
    if name == 'cifar10':
        dataset = CIFAR10_LT(
            args.distributed, root=os.path.join(args.root, 'CIFAR10'),
            imb_type=args.imb_type, imb_factor=args.imb_factor,
            batch_size=args.batch_size, num_works=args.num_workers)
    elif name == 'cifar100':
        dataset = CIFAR100_LT(
            args.distributed, root=os.path.join(args.root, 'CIFAR100'),
            imb_type=args.imb_type, imb_factor=args.imb_factor,
            batch_size=args.batch_size, num_works=args.num_workers)
    elif name == 'places':
        dataset = Places_LT(
            args.distributed, root=args.root,
            batch_size=args.batch_size, num_works=args.num_workers)
    elif name == 'imagenet':
        dataset = ImageNet_LT(
            args.distributed, root=args.root,
            batch_size=args.batch_size, num_works=args.num_workers)
    elif name == 'ina2018':
        dataset = iNa2018(
            args.distributed, root=args.root,
            batch_size=args.batch_size, num_works=args.num_workers)
    else:
        # Without this, an unknown name fell through to `return dataset`
        # and raised an opaque UnboundLocalError.
        raise NotImplementedError(
            f"Unsupported imbalanced dataset: {args.dataset!r}")
    return dataset